# ------------------------------------------------------------------------------
# Creates ensembling sets for fold models
# Updates author: Cassidy Shubatt <cshubatt@gmail.com>
# To run: bash 01_create_ensembles.sh {downsample} {split}
# ------------------------------------------------------------------------------

# Seeding ----------------------------------------------------------------------
set.seed(1)

# Libraries --------------------------------------------------------------------
message("Loading libraries...")
library(yaml)
library(data.table)
library(glue)
library(Matrix)
library(doParallel)
library(glmnet)
library(optparse)
library(here)
library(testit)
library(tidyverse)

u <- modules::use(here("lib", "util.R"))

# Command Line Args ------------------------------------------------------------
# Parses the single `--split` option, which names the cohorts/<split>
# subdirectory that this script reads from and writes to.
message("Parsing command line args...")
arg_config <- list(
  make_option(
    "--split",
    type = "character",
    help = "Cohort split name; selects the cohorts/<split> subdirectory"
  )
)
arg_parser <- OptionParser(option_list = arg_config)
arg_list <- parse_args(arg_parser)
split <- arg_list$split

# A missing --split leaves arg_list$split as NULL, and file.path(..., NULL)
# collapses to character(0), so fail here with an explicit message instead.
if (is.null(split) || is.na(split) || !nzchar(split)) {
  stop("--split is required, e.g. --split <split_name>", call. = FALSE)
}

# Directories ------------------------------------------------------------------
# Resolves the input and output directories for the requested split from the
# project's filepaths.yml.
message("Establishing directories...")
build_models_dir <- here::here("code", "03_analysis", "01_build_models")
paths <- read_yaml(here("lib", "filepaths.yml"))
# Inputs and outputs share the cohorts/<split> layout under different roots.
modeling_paths <- paths$modeling
load_dir <- file.path(modeling_paths$dir, "cohorts", arg_list$split)
save_dir <- file.path(modeling_paths$oos, "cohorts", arg_list$split)

# Load Data --------------------------------------------------------------------
# Reads the training cohort whose `ptid` and `train_fold` columns drive the
# per-fold ensemble sampling.
message("Loading data...")
ids <- readRDS(file.path(load_dir, "train_cohort.rds"))

# The fold loop reads exactly these two columns; check for them up front so a
# schema mismatch is reported by name rather than as a filter() error.
missing_cols <- setdiff(c("ptid", "train_fold"), names(ids))
if (length(missing_cols) > 0) {
  stop(
    "train_cohort.rds is missing required column(s): ",
    paste(missing_cols, collapse = ", "),
    call. = FALSE
  )
}

# Ensemble folds ---------------------------------------------------------------
# For each training fold, flags a random ~6% of that fold's unique patients
# (rounded up) as the ensemble set. The result has one row per (fold, patient)
# with columns `ptid` and `in_ensemble`.
message("Creating ensemble sets within folds...")
n_folds <- 5L
ensemble_frac <- 0.06

# Folds are processed in order 1..n_folds so the draws consume the seeded RNG
# stream in the same sequence on every run.
fold_sets <- lapply(seq_len(n_folds), function(i) {
  fold <- ids %>% filter(train_fold == i)
  patients <- unique(fold$ptid)
  n_ensemble <- ceiling(ensemble_frac * length(patients))
  # Draw by index rather than sample(patients, n): when a fold holds a single
  # numeric id, sample() would draw from 1:id instead of from the id itself.
  # For any other fold size this consumes the RNG exactly as sample() does.
  ensemble_ids <- patients[sample.int(length(patients), n_ensemble)]
  tibble(ptid = patients, in_ensemble = patients %in% ensemble_ids)
})

# Bind once instead of growing with rbind() inside the loop; `ptid` keeps the
# type it has in `ids`.
ensemble <- bind_rows(fold_sets)

# Save -------------------------------------------------------------------------
# Attaches the `in_ensemble` flag to each downsampled training cohort and
# writes the result under save_dir using the same file name it was read from.
downsamples <- c("non_downsampled", "tested", "ds_mace", "ds_test")

# write_rds() does not create missing directories; make sure the target exists
# before doing any work (a no-op when it is already there).
dir.create(save_dir, recursive = TRUE, showWarnings = FALSE)

for (ds in downsamples) {
  cohort_file <- glue("train_cohort_{ds}.rds")
  save_fp <- file.path(save_dir, cohort_file)
  message("Saving ", ds, " cohort to ", save_fp, "...")
  ds_df <- readRDS(file.path(load_dir, cohort_file)) %>%
    u$safe_left_join(ensemble)
  write_rds(ds_df, save_fp)
}

message("Done.")
